{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ee7c2f43-1333-4061-bc3a-6d223acef12c",
   "metadata": {},
   "source": [
    "This notebook can be used to test the state dict mapping between the original MDETR repo's checkpoints (with ResNet backbone, found [here](https://github.com/ashkamath/mdetr#pre-training)) and the refactored TorchMultimodal classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e9c13575-2049-4b20-9ae3-5fd9b93109e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "download_dir = \"/data/home/ebs/data/mdetr\"\n",
    "repo_dir = \"/data/home/ebs\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6798f6c6-5004-4ee8-a98f-4edf25410185",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install MDETR repo \n",
    "!git clone https://github.com/ashkamath/mdetr.git $repo_dir\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29a09c61-cd5f-47fc-9552-2b23533a2a8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download checkpoint\n",
    "!wget https://zenodo.org/record/4721981/files/pretrained_resnet101_checkpoint.pth?download=1 -P $download_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "dff97610-8c5e-4f06-8aee-11baa3f52a60",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data/home/ebs/miniconda3/envs/mdetr-notebook/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "/data/home/ebs/miniconda3/envs/mdetr-notebook/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.\n",
      "  warnings.warn(\n",
      "/data/home/ebs/miniconda3/envs/mdetr-notebook/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=ResNet101_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet101_Weights.DEFAULT` to get the most up-to-date weights.\n",
      "  warnings.warn(msg)\n",
      "Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.bias', 'lm_head.layer_norm.bias']\n",
      "- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "import sys \n",
    "sys.path.append(repo_dir)\n",
    "sys.path.append(os.path.join(repo_dir,\"mdetr\"))\n",
    "\n",
    "# Load MDETR classes and ResNet101 weights\n",
    "import torch\n",
    "from torch import nn\n",
    "from mdetr.models import build_model\n",
    "from mdetr.main import get_args_parser\n",
    "import argparse\n",
    "\n",
    "mdetr = torch.load(os.path.join(download_dir,\"pretrained_resnet101_checkpoint.pth?download=1\"), map_location=torch.device('cpu'))\n",
    "\n",
    "parser = argparse.ArgumentParser(\"DETR training and evaluation script\", parents=[get_args_parser()])\n",
    "\"--dataset_config=wef --device=cpu\"\n",
    "args = parser.parse_args(['--dataset_config', 'wef', '--device', 'cpu'])\n",
    "model, criterion, contrastive_criterion, qa_criterion, weight_dict = build_model(args)\n",
    "\n",
    "model.load_state_dict(mdetr['model'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0d45b440-3021-4a6c-9069-51277919e6b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn, Tensor\n",
    "from typing import Dict, List\n",
    "\n",
    "# Define a couple helper functions\n",
    "def filter_dict(key_condition, d):\n",
    "    return {k: v for k, v in d.items() if key_condition(k)}\n",
    "\n",
    "def get_params_for_layer(state_dict, i):\n",
    "    return [x for x in state_dict.keys() if f'layer.{i}.' in x or f'layers.{i}' in x]\n",
    "\n",
    "# Mapping from TorchText layers to Hugging Face ones\n",
    "# TorchText's input projection should equal the concatenation of \n",
    "# Hugging Face's Q,K,V matrics\n",
    "param_mapping = {\n",
    "    'self_attn.in_proj_weight': ['attention.self.query', 'attention.self.key', 'attention.self.value'],\n",
    "    'self_attn.in_proj_bias': ['attention.self.query', 'attention.self.key', 'attention.self.value'], \n",
    "    'self_attn.out_proj': 'attention.output.dense',\n",
    "    'norm1': 'attention.output.LayerNorm',\n",
    "    'linear1': 'intermediate.dense',\n",
    "    'linear2': 'output.dense',\n",
    "    'norm2': 'output.LayerNorm',\n",
    "}\n",
    "\n",
    "# These are the prefixes of the text encoder layers as they occur in Hugging Face and TorchText\n",
    "hf_layer_prefix = 'transformer.text_encoder.encoder.layer'\n",
    "tt_layer_prefix = 'text_encoder.encoder.layers.layers'\n",
    "\n",
    "postfixes = ['weight', 'bias']\n",
    "\n",
    "# Create a state dict for ith layer of TorchText RoBERTa encoder\n",
    "# for storing weights from ith layer of Hugging Face's encoder\n",
    "def map_layer(hf_state_dict, tt_state_dict, i):\n",
    "    mapped_state_dict = {}\n",
    "    hf_layer = get_params_for_layer(hf_state_dict, i)\n",
    "    tt_layer = get_params_for_layer(tt_state_dict, i)\n",
    "    for tt_key_short, hf_key_short in param_mapping.items():\n",
    "        tt_key_short = '.'.join([tt_layer_prefix, str(i), tt_key_short])\n",
    "        # For Q,K,V matrices we need to concat the weights\n",
    "        if isinstance(hf_key_short, List):\n",
    "            hf_keys_short = list(map(lambda x: '.'.join([hf_layer_prefix, str(i), x]), hf_key_short))\n",
    "            # for postfix in postfixes:\n",
    "            postfix = tt_key_short.split('_')[-1]\n",
    "            hf_keys = ['.'.join([x, postfix]) for x in hf_keys_short]\n",
    "            if not any([x in tt_key_short for x in postfixes]):\n",
    "                tt_key = '.'.join([tt_key_short, postfix])\n",
    "            else:\n",
    "                tt_key = tt_key_short\n",
    "            # print(f\"COMBINING {hf_keys}\")\n",
    "            qkv_combined = torch.concat([hf_state_dict[hf_key] for hf_key in hf_keys])\n",
    "            # print(f\"qkv_combined size is {qkv_combined.size()}\")\n",
    "            # print(f\"Mapping into {tt_key}\")\n",
    "            mapped_state_dict[tt_key] = qkv_combined\n",
    "        else:\n",
    "            hf_key_short = '.'.join([hf_layer_prefix, str(i), hf_key_short])\n",
    "            for postfix in postfixes:\n",
    "                tt_key = '.'.join([tt_key_short, postfix])\n",
    "                hf_key = '.'.join([hf_key_short, postfix])\n",
    "                mapped_state_dict[tt_key] = hf_state_dict[hf_key]\n",
    "\n",
    "    return mapped_state_dict\n",
    "\n",
    "    \n",
    "# Just a for loop around the text encoder layer mapping\n",
    "def map_text_encoders(hf_state_dict: Dict[str, Tensor], tt_state_dict: Dict[str, Tensor], n_layers: int = 12):\n",
    "    mapped_state_dict = {}\n",
    "    for i in range(n_layers):\n",
    "        mapped_state_dict.update(map_layer(hf_state_dict, tt_state_dict, i))\n",
    "    return mapped_state_dict\n",
    "\n",
    "\n",
    "# The main function used to map from the MDETR state dict to the TorchMultimodal one\n",
    "# TODO: refactor to remove the explicit dependency on n_layers\n",
    "def map_mdetr_state_dict(mdetr_state_dict, mm_state_dict, n_layers: int = 12): \n",
    "    # Perform the text encoder mapping\n",
    "    mapped_state_dict = map_text_encoders(\n",
    "        mdetr_state_dict, \n",
    "        mm_state_dict,\n",
    "        n_layers=12\n",
    "    )\n",
    "    \n",
    "    # Miscellaneous renaming (this can probably be cleaned up)\n",
    "    mapped_state_dict = {k.replace('transformer.text_encoder', 'text_encoder'): v for k, v in mapped_state_dict.items() if 'embeddings' not in k}\n",
    "\n",
    "    for k, v in mdetr_state_dict.items():\n",
    "        if not k.startswith('transformer.text_encoder') and not k.startswith('transformer.resizer') and 'input_proj' not in k:\n",
    "            mapped_state_dict[k.replace('backbone.0', 'image_backbone')] = v\n",
    "        if 'embeddings' in k:\n",
    "            mapped_state_dict[k.replace('transformer.','')] = v\n",
    "        if 'input_proj' in k:\n",
    "            mapped_state_dict[k.replace('input_proj','image_projection')] = v\n",
    "        if 'resizer' in k:\n",
    "            mapped_state_dict[k.replace('transformer.','').replace('resizer', 'text_projection')] = v\n",
    "        if 'embeddings.LayerNorm' in k:\n",
    "            new_k = k.replace('transformer.','')\n",
    "            mapped_state_dict[new_k.replace('LayerNorm', 'layer_norm')] = v\n",
    "            del mapped_state_dict[new_k]\n",
    "            # mapped_state_dict[f\"text_encoder.encoder.embedding_layer_norm.{k.split('.')[-1]}\"] = v\n",
    "        if 'bbox_embed' in k:\n",
    "            parsed = k.split('.')\n",
    "            i = int(parsed[parsed.index('layers') + 1])\n",
    "            mapped_state_dict[k.replace('layers','model').replace(str(i), str(2*i))] = v\n",
    "            del mapped_state_dict[k]\n",
    "        if all([x in k for x in ['transformer', 'layers', 'linear']]):\n",
    "            k_split = k.split('.')\n",
    "            i = int(k_split[-2][-1])\n",
    "            k_new = '.'.join(k_split[:-2] + [\"mlp\", \"model\", str(3*(i-1)), k_split[-1]])\n",
    "            mapped_state_dict[k_new] = v\n",
    "            del mapped_state_dict[k]\n",
    "        if 'contrastive' in k:\n",
    "            k_new = k.replace('align','alignment').replace('projection_image', 'image_projection').replace('projection_text', 'text_projection')\n",
    "            mapped_state_dict[k_new] = v\n",
    "            del mapped_state_dict[k]\n",
    "    \n",
    "    return mapped_state_dict\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3ced58d3-f4a2-4fc5-86c2-5cd8c6f49d6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn, Tensor\n",
    "from mdetr.models.mdetr import MDETR\n",
    "from mdetr.models.transformer import Transformer\n",
    "import unittest\n",
    "# from torchmultimodal.utils.common import NestedTensor\n",
    "from transformers import RobertaTokenizerFast\n",
    "\n",
    "from torchmultimodal.models.mdetr.image_encoder import mdetr_resnet101_backbone\n",
    "from torchmultimodal.models.mdetr.text_encoder import mdetr_roberta_text_encoder\n",
    "from torchmultimodal.models.mdetr.transformer import MDETRTransformer as mm_Transformer\n",
    "from torchmultimodal.models.mdetr.model import mdetr_resnet101\n",
    "\n",
    "max_diff = lambda x, y: torch.max(torch.abs(x - y))\n",
    "\n",
    "# This is the class for testing the state dict mapping\n",
    "class TestMDETR(unittest.TestCase):\n",
    "\n",
    "    def setUp(self):\n",
    "        self.test_tensors = torch.rand(2, 3, 64, 64).unbind(dim=0)\n",
    "        mask = torch.randint(0, 2, (2, 64, 64))\n",
    "        # self.samples = NestedTensor(test_tensor, mask)\n",
    "        self.captions = ['I can see the sun', 'But even if I cannot see the sun, I know that it exists']\n",
    "        self.tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')\n",
    "        self.text = self.tokenizer.batch_encode_plus(self.captions, padding=\"longest\", return_tensors=\"pt\")\n",
    "        self.mdetr = model\n",
    "        self.mdetr.eval()\n",
    "  \n",
    "    def run_mdetr(self):\n",
    "        self.memory_cache = self.mdetr(self.test_tensors, self.captions, encode_and_save=True)\n",
    "        self.mdetr_out = self.mdetr(self.test_tensors, self.captions, encode_and_save=False, memory_cache=self.memory_cache)\n",
    "\n",
    "        \n",
    "    def run_mm_mdetr(self):\n",
    "        self.mm_mdetr = mdetr_resnet101()\n",
    "        self.mapped_state_dict = map_mdetr_state_dict(self.mdetr.state_dict(), self.mm_mdetr.state_dict())\n",
    "\n",
    "        \n",
    "        self.mm_mdetr.load_state_dict(self.mapped_state_dict)\n",
    "        self.mm_mdetr.eval()\n",
    "        self.mm_out = self.mm_mdetr(self.test_tensors, self.text.input_ids)\n",
    "        self.mm_out_dict = {\n",
    "            'pred_logits': self.mm_out.pred_logits, \n",
    "            'pred_boxes': self.mm_out.pred_boxes, \n",
    "            'proj_queries': self.mm_out.projected_queries,\n",
    "            'proj_tokens': self.mm_out.projected_tokens\n",
    "            \n",
    "        }\n",
    "    def compare_results(self):\n",
    "        for k in self.mm_out_dict.keys():\n",
    "            tensor_diff = max_diff(self.mm_out_dict[k], self.mdetr_out[k])\n",
    "            print(f\"Maximum difference in {k} is {tensor_diff}\")\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "41bd2387-1b23-4a8e-a852-93cc49e8ae1b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data/home/ebs/mdetr/models/position_encoding.py:41: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
      "  dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum difference in pred_logits is 1.239776611328125e-05\n",
      "Maximum difference in pred_boxes is 3.337860107421875e-06\n",
      "Maximum difference in proj_queries is 2.384185791015625e-07\n",
      "Maximum difference in proj_tokens is 3.129243850708008e-07\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data/home/ebs/torchmultimodal/torchmultimodal/modules/encoders/mdetr_image_encoder.py:96: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
      "  dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)\n"
     ]
    }
   ],
   "source": [
    "# Run the test\n",
    "tester = TestMDETR()\n",
    "tester.setUp()\n",
    "tester.run_mdetr()\n",
    "tester.run_mm_mdetr()\n",
    "tester.compare_results()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "241e5e0e-db39-4e95-87c5-6a02d8326fee",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.12 ('torchmm')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "af8d201ff4cb28e0071e88fa2040902e5cd6249ffbfa0955260144574496630b"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
